{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4d252bab-f488-489f-a3a3-a8869cf99dcb",
   "metadata": {
    "tags": []
   },
   "source": [
    "1. 安装最新版本的modelscope和swift"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68bbc6b0-e4e5-4663-a28e-fec1e667427b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Pin ms-swift below 3.0: later cells use the 2.x API\n",
    "# (swift.llm.utils.preprocess, TemplateType.qwen2_5, tuple-returning template.encode).\n",
    "# NOTE(review): ms-swift 3.x is believed to have reorganized these modules - confirm before lifting the pin.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -U modelscope \"ms-swift<3.0\"\n",
    "%pip install tf-keras==2.16.0 --no-dependencies"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bc13d59-415b-4455-a965-49cead0904f7",
   "metadata": {},
   "source": [
    "2. 加载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9a292aab-04d2-4b68-8507-53f1ea22bb9f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-11-01T08:24:57.831096Z",
     "iopub.status.busy": "2024-11-01T08:24:57.830826Z",
     "iopub.status.idle": "2024-11-01T08:25:07.688996Z",
     "shell.execute_reply": "2024-11-01T08:25:07.688492Z",
     "shell.execute_reply.started": "2024-11-01T08:24:57.831080Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from modelscope import MsDataset\n",
    "\n",
    "# Load the dataset by its id and keep only the 'train' split.\n",
    "dataset = MsDataset.load('swift/classical_chinese_translate')['train']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe841e73-3f85-4a4d-879c-0b9064fe9e98",
   "metadata": {},
   "source": [
    "3. 查看数据集内容"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "817f4c82-cf51-4504-8866-09eff41fe732",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Show the dataset summary, then the first entry of the 'conversations' column.\n",
    "print(dataset)\n",
    "first_conversation = dataset['conversations'][0]\n",
    "print(first_conversation)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac948dfd-f290-43c4-bf4f-7ece4883ccad",
   "metadata": {},
   "source": [
    "4. 加载模型、分词器、模板"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "231ff476-df34-470b-8b73-4bd24e2f298a",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from modelscope import AutoModelForCausalLM, AutoTokenizer\n",
    "from swift.llm import TemplateType, get_template\n",
    "\n",
    "# One model id shared by the tokenizer and the model so they cannot drift apart.\n",
    "model_id = 'Qwen/Qwen2.5-7B-Instruct'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_id,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map='auto',\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "print(model)\n",
    "print(tokenizer('I like AI'))\n",
    "\n",
    "# Build the chat template and encode one query/response pair to inspect the result.\n",
    "template = get_template(TemplateType.qwen2_5, tokenizer, max_length=400)\n",
    "sample = {'query': 'what is your hobby?', 'response': 'I like AI'}\n",
    "ret = template.encode(sample)\n",
    "print(ret)\n",
    "# Bare last expression so the notebook displays the decoded text;\n",
    "# ret[0] is the first element of the pair returned by template.encode.\n",
    "tokenizer.decode(ret[0]['input_ids'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4dd66f8b-921c-4433-be8d-122bc182920c",
   "metadata": {},
   "source": [
    "5. 预处理数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41eebe2-7009-41cf-a936-5d4b471daba4",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from swift.llm.utils.preprocess import ConversationsPreprocessor\n",
    "ds = ConversationsPreprocessor()(dataset)\n",
    "print(ds[0])\n",
    "\n",
    "def encode(example):\n",
    "    \"\"\"Encode one example with the chat template.\n",
    "\n",
    "    Returns the encoded fields, or a placeholder with ``None`` values when\n",
    "    the template produced no ``input_ids``, so the row can be filtered out.\n",
    "    \"\"\"\n",
    "    # template.encode returns a pair; only the first element is used here.\n",
    "    encoded, _ = template.encode(example)\n",
    "    if 'input_ids' not in encoded:\n",
    "        return {\n",
    "            'input_ids': None,\n",
    "            'labels': None,\n",
    "        }\n",
    "    return encoded\n",
    "\n",
    "# Keep the demo small: encode the first 300 rows and drop rows without input_ids.\n",
    "ds = ds.select(range(300)).map(encode).filter(lambda e: e.get('input_ids'))\n",
    "# Fixed seed so the train/validation split is the same on every run.\n",
    "ds = ds.train_test_split(test_size=0.1, seed=42)\n",
    "\n",
    "train_dataset, val_dataset = ds['train'], ds['test']\n",
    "print('===========================================')\n",
    "print(train_dataset[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8eb90b4e-7e65-48d7-9f62-d9b226918618",
   "metadata": {},
   "source": [
    "6. 加载LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9cf58c93-a3e6-4672-80d9-06acc25bd9e8",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2024-11-01T08:25:51.704754Z",
     "iopub.status.busy": "2024-11-01T08:25:51.704328Z",
     "iopub.status.idle": "2024-11-01T08:25:52.427483Z",
     "shell.execute_reply": "2024-11-01T08:25:52.426958Z",
     "shell.execute_reply.started": "2024-11-01T08:25:51.704723Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from swift import LoraConfig, Swift\n",
    "\n",
    "# Names of the modules that receive LoRA adapters.\n",
    "lora_target_modules = [\n",
    "    'q_proj', 'k_proj', 'v_proj', 'o_proj',\n",
    "    'gate_proj', 'up_proj', 'down_proj',\n",
    "]\n",
    "\n",
    "lora_config = LoraConfig(\n",
    "    r=8,\n",
    "    lora_alpha=32,\n",
    "    lora_dropout=0.05,\n",
    "    target_modules=lora_target_modules,\n",
    ")\n",
    "\n",
    "# NOTE(review): `model` is rebound to the prepared model, so re-running this cell\n",
    "# passes the already-prepared model back in - restart the kernel to start clean.\n",
    "model = Swift.prepare_model(model, lora_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a06c9e0-24ff-49a7-b901-54d800248407",
   "metadata": {},
   "source": [
    "7. 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "add78662-d30f-48b7-af7c-358840af1433",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Original author's note: A100, about 18G of memory.\n",
    "import torch\n",
    "from swift import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
    "\n",
    "train_args = Seq2SeqTrainingArguments(\n",
    "    # Where checkpoints are written.\n",
    "    output_dir='output',\n",
    "    # Optimisation schedule.\n",
    "    num_train_epochs=1,\n",
    "    learning_rate=1e-4,\n",
    "    per_device_train_batch_size=1,\n",
    "    gradient_accumulation_steps=16,\n",
    "    dataloader_num_workers=4,\n",
    "    # Logging and checkpointing cadence, in steps.\n",
    "    logging_steps=2,\n",
    "    save_strategy='steps',\n",
    "    save_steps=5,\n",
    "    # NOTE(review): evaluation is switched off, so eval_steps and the\n",
    "    # eval_dataset passed below look unused - confirm this is intended.\n",
    "    evaluation_strategy='no',\n",
    "    eval_steps=5,\n",
    ")\n",
    "\n",
    "# Sanity check: show one encoded training example before training starts.\n",
    "print(train_dataset[0])\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model=model,\n",
    "    args=train_args,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=template.data_collator,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "741ed1fd-1623-4221-a7c3-efdd37bc57da",
   "metadata": {},
   "source": [
    "8. 看看存了什么"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e1a5e23-632d-4487-8a5c-fcae84311f51",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# The checkpoint step number depends on how many examples survive preprocessing\n",
    "# and on the batch / accumulation settings, so list every checkpoint directory\n",
    "# instead of hard-coding one step.\n",
    "!ls output/checkpoint-*"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "799089b3-5458-4b54-9865-131fca081b26",
   "metadata": {},
   "source": [
    "9. 推理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00f72b82-e856-48cd-9be5-38afee347b48",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from modelscope import GenerationConfig\n",
    "\n",
    "query = '你喜欢什么'\n",
    "# Encode a query-only example; the second element of the returned pair is not used.\n",
    "inputs, kwargs = template.encode({'query': query})\n",
    "print(inputs)\n",
    "\n",
    "# Switch to inference mode: after trainer.train() the model can still be in\n",
    "# train mode, which would leave LoRA dropout active while generating.\n",
    "model.eval()\n",
    "\n",
    "# top_p / temperature only take effect when sampling is enabled.\n",
    "generation_config = GenerationConfig(max_new_tokens=512, do_sample=True, top_p=0.7, temperature=0.3)\n",
    "# Add a leading batch dimension and move the prompt ids to the GPU.\n",
    "inputs['input_ids'] = torch.tensor(inputs['input_ids'])[None].cuda()\n",
    "generate_ids = model.generate(generation_config=generation_config, **inputs)\n",
    "print(generate_ids)\n",
    "print(tokenizer.decode(generate_ids[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95f72cea-a972-4e0c-aad4-4a6f1a5aca25",
   "metadata": {},
   "source": [
    "10. 界面"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbf4c631-6d3f-4813-8380-96fbf632dac9",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs gradio into the environment of the running kernel.\n",
    "%pip install gradio==3.50.2\n",
    "# Launch the swift web UI from the shell.\n",
    "!swift web-ui"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
